Skip to content

feat(nvidia): add ntops rms norm backend#616

Draft
voltjia wants to merge 1 commit into
masterfrom
feat/nvidia-ntops-rms-norm
Draft

feat(nvidia): add ntops rms norm backend#616
voltjia wants to merge 1 commit into
masterfrom
feat/nvidia-ntops-rms-norm

Conversation

@voltjia
Copy link
Copy Markdown
Collaborator

@voltjia voltjia commented May 20, 2026

@/tmp/pr616-body.md

Comment thread CMakeLists.txt Outdated
option(WITH_TORCH "Enable PyTorch C++ backend" OFF)

option(WITH_NINETOOTHED "Enable NineToothed-generated NVIDIA kernels" OFF)
set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run ninetoothed code generation")
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分主要是用来写 option 的,请把下面这堆 set 给挪到一个专门的 section。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已改。WITH_NINETOOTHED 仍然放在 option 区,下面这些 cache 变量已经挪到单独的 NineToothed code generation configuration section 里。

Comment thread scripts/generate_ninetoothed_ops.py Outdated
_SUPPORTED_OPS = ("rms_norm",)


def _import_ninetoothed(source_dir):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥不直接是 import ninetoothed?且请不要使用缩写,直接使用全称 ninetoothed

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已改。现在实现移到了 src/native/ninetoothed/codegen.py_import_ninetoothed 只在可选 source dir 需要时调整 sys.path,随后直接 import ninetoothed,变量名也不再用 nt 缩写。

Comment thread scripts/generate_ninetoothed_ops.py Outdated
return nt


def _import_ntops():
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥不直接是 import ntops

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已改。_rms_norm_premake 里直接 import ntops,不再通过 importlib_import_ntops 包一层。

Comment thread scripts/generate_ninetoothed_ops.py Outdated

_DEFAULT_DTYPES = ("float32", "float16", "bfloat16")

_DEFAULT_RMS_NORM_SHAPES = (
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

具体算子相关的内容,应该放到 src 里,scripts 里面只放纯功能性工具或者构建相关脚本。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已改。scripts/generate_ninetoothed_ops.py 现在只作为构建入口,把 src 加到 sys.path 后委托给 native.ninetoothed.codegen.main();具体算子和生成逻辑放到了 src/native/ninetoothed/codegen.py

Comment thread scripts/generate_ninetoothed_ops.py Outdated
return importlib.import_module("ntops")


def _rms_norm_premake_rank2(dim0, dim1, dtype, block_size):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,这部分应当放在 src 中合适的地方,而不是在 scripts 下。以下同类问题不再赘述,但请一并修改。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已一并修改。RmsNorm 的 premake 包装、rank/dtype config、manifest 生成都挪到了 src/native/ninetoothed/codegen.pyscripts 下不再放算子细节。

Comment thread scripts/generate_ninetoothed_ops.py Outdated
return arrangement, application, tensors


def _rms_norm_premake_rank3(dim0, dim1, dim2, dtype, block_size):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这为啥要分 rank?不是只有 shape 不一样,那不是传个 shape 就行了嘛?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已改。之前按 rank 拆函数是为了让 ninetoothed.build 生成不同 launcher 参数;现在改成使用 ntops 自带的动态-rank premake,只按 ndim/dtype 生成同一个 infiniops_ninetoothed_rms_norm dispatcher,Python 侧不再拆 rank2/rank3 premake。

Comment thread scripts/generate_ninetoothed_ops.py Outdated
return arrangement, application, tensors


def _parse_shape(value):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数是干嘛的?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除。现在不再按具体 shape 编译,也不需要解析 1x64 这类 shape 字符串;配置改为 INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS / --rms-norm-ndims


namespace detail {

inline int NineToothedRmsNormDTypeIndex(DataType dtype) {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种类似的 helper 不是应该是整个九齿 common 的嘛?不要放在 rms_norm 下面。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已改。DTypeIndexSizeArgFromTensorFromScalar 都放到 src/native/ninetoothed/tensor.h 作为九齿 common helper;rms_norm/ninetoothed.h 只保留 ExpandedRmsNormWeight 这种算子特有适配和 generated launcher 调用。

@voltjia voltjia force-pushed the feat/nvidia-ntops-rms-norm branch from fa89de9 to 0ad2354 Compare May 20, 2026 11:23
_SUPPORTED_OPS = ("rms_norm",)


def _import_ninetoothed(source_dir):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

到底为啥需要这个 helper?去掉它,直接在 top-level import ninetoothed 就行。

Comment thread CMakeLists.txt

option(WITH_TORCH "Enable PyTorch C++ backend" OFF)

option(WITH_NINETOOTHED "Enable NineToothed-generated NVIDIA kernels" OFF)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处不提及 NVIDIA,因为九齿的目标是跨平台,只是目前可能只暴露了 cuda caller,所以跟 PyTorch 的对齐即可。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

九齿的定位在算子库中应该跟 PyTorch 差不多,都可以接入到后端里,所以在文件结构上应该跟 PyTorch 平行,而不是放在 cuda 下,现在的九齿可能只有 cuda 这个 caller,但是生成的接口是一致的,只要后期增多了支持,就可以跨平台,跟 PyTorch 一样。

#ifndef INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_
#define INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_

#ifdef WITH_NINETOOTHED
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

就像评论 https://github.com/InfiniTensor/InfiniOps/pull/616/changes#r3285560677 所说,九齿应该是与 PyTorch 等对应的后端,所以是跟 torch 差不多的文件架构,而咱们算子库都是靠 build system 和脚本来确定最终产物,所以不要在 src 里面的文件使用 WITH_NINETOOTHED 这种类似的宏。事实上,在 C++ 中,我们应当尽量少地使用宏。

#include "rms_norm/infiniops_ninetoothed_rms_norm.h"

#ifndef INFINIOPS_NINETOOTHED_BLOCK_SIZE
#define INFINIOPS_NINETOOTHED_BLOCK_SIZE 256
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C++ 中尽量不使用宏,尤其是这种可以被 constexpr 或者 const 替代的情况。

Comment thread src/native/ninetoothed/tensor.h Outdated
};
}

template <typename NineToothedTensor, typename T>
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个模板参数的意义是什么?暂时是冗余的。如无必要,吴增实体。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已处理:删掉了原来的 FromScalar<NineToothedTensor> 函数模板,标量也统一通过 ninetoothed::Tensor(value, empty_shape, empty_strides) 包装后传给生成 launcher。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我这里有一套方案,可以看看是不是更好一些:我们直接提供一个 infini::ops::ninetoothed::Tensor 类,这个类里面去定义从 infini::ops::Tensor 或者 scalar 到它的 implicit conversion。这样用起来会不会更方便一些?也请考虑一些可能的风险,综合评判是否这么做。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已按这个方向调整:现在提供 infini::ops::ninetoothed::Tensor,负责从 InfiniOps Tensor、标量或自定义 shape/stride 视图适配到 NineToothed launcher 参数。一个实现上的取舍是:公共头不直接依赖某个 op 生成出的 NineToothedTensor 定义,因为这个类型来自生成 header;因此转换在调用点按目标参数类型延迟实例化。这样可以保留 ninetoothed::Tensor(input) 的简洁使用方式,同时避免 include 顺序和 op-specific 生成头泄漏到公共适配层。

Comment thread CMakeLists.txt Outdated

set(${out_var} "" PARENT_SCOPE)
endfunction()
# NineToothed code generation configuration.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请在后面创建一个关于 WITH_NINETOOTHEDif 吧,把这些 set 放到这个分支里吧。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已处理:NINETOOTHED_*INFINIOPS_NINETOOTHED_* cache 配置现在都放在后面的 if(WITH_NINETOOTHED) 分支里,默认不开启时不再提前暴露这些配置项。

Comment thread tests/test_generate_ninetoothed_ops.py Outdated
import sys
import tempfile
import types
import unittest
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么引入了 unittest?请统一使用 pytest,与其他测试保持一致。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已处理:测试已改成 pytest 风格的普通测试函数,使用 monkeypatch,不再引入 unittestunittest.mock


namespace infini::ops {

namespace detail {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我看此处 detail 内部的函数比较少,且每个函数的内容也很少,可以考虑直接放在 Operator<RmsNorm, Device::Type::kNvidia, 9>::operator() 里面,暂时不单独抽成独立的 helper 了。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已处理:去掉了 detail namespace 中的两个小 helper,把 weight 扩展、dtype index 和 generated launcher 调用都放回 Operator<RmsNorm, Device::Type::kNvidia, 9>::operator() 内。当前逻辑比较短,内联后更直接。

@voltjia voltjia force-pushed the feat/nvidia-ntops-rms-norm branch from 0ad2354 to eff11f2 Compare May 23, 2026 09:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant